# ================================ 2.转换数据格式 ==============================================
from torch.utils.data import Dataset
from datasets import load_dataset, load_from_disk


class MyDataset(Dataset):
    """Map-style PyTorch dataset over one split of a Hugging Face dataset saved on disk.

    Each item is a ``(text, label)`` tuple taken from the ``"text"`` and
    ``"label"`` fields of the selected split.
    """

    # Directory passed to ``load_from_disk`` when no ``path`` is given; kept as the
    # default so existing ``MyDataset(split)`` calls behave exactly as before.
    DEFAULT_PATH = r"D:\code\base_learn\learn_ai\hugging_face\saved_dataset"
    # Split names accepted by ``__init__``.
    SPLITS = ("train", "test", "validation")

    def __init__(self, split, path=DEFAULT_PATH):
        """Load the requested split of the dataset stored at ``path``.

        :param split: one of ``"train"``, ``"test"`` or ``"validation"``.
        :param path: directory handed to ``datasets.load_from_disk``; defaults
            to ``DEFAULT_PATH``.
        :raises ValueError: if ``split`` is not one of ``SPLITS``.
        """
        # Validate before touching the disk so a typo fails fast instead of
        # after loading every split.
        if split not in self.SPLITS:
            raise ValueError(f"Invalid split: {split!r}; expected one of {self.SPLITS}")
        # The loaded object is indexed by split name (presumably a DatasetDict
        # written with ``save_to_disk``); keep only the requested split.
        self.dataset = load_from_disk(path)[split]

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the ``(text, label)`` pair stored at ``index``.

        :param index: position of the row in the selected split.
        :return: tuple ``(text, label)``.
        """
        # Fetch the row once and read both fields from it, rather than
        # indexing (and decoding) the underlying dataset twice.
        row = self.dataset[index]
        return row["text"], row["label"]


if __name__ == '__main__':
    # Smoke check: print every (text, label) pair of the test split.
    test_split = MyDataset("test")
    for sample_text, sample_label in test_split:
        print(sample_text, sample_label)